import os
from multiprocessing import Pool

import nltk
from utils.metrics.Metrics import Metrics
from nltk import ngrams
import tensorflow as tf
from models.seqgan.SeqganReward import Reward
from utils.utils import *

class DualityGap(Metrics):
    """Duality-gap style metric computed by temporarily perturbing a GAN.

    ``get_score`` works in these steps:

    1. Snapshot the current generator/discriminator weights into the "tmp"
       variables.
    2. Run ``steps`` generator updates (rewards come from the fixed
       discriminator). Then measure ``discriminator.d_loss_no_reg`` on one
       batch (``worst_maxmin``).
    3. Restore the snapshot.
    4. Run ``steps`` discriminator training updates against the fixed
       generator. Then measure ``d_loss_no_reg`` again (``worst_minmax``).
    5. Restore the snapshot again.

    The score is ``(-worst_minmax) - (-worst_maxmin)``.

    Snapshot/restore pairs TensorFlow variables by position in
    ``tf.global_variables()``:

    - names containing ``'disc_tmp'`` mirror names containing ``'discriminator'``;
    - names containing ``'gen_tmp'`` mirror names containing ``'generator_1'``;
    - any name containing ``'Adam'`` is excluded (presumably optimizer slots).
    """

    def __init__(self, generator, discriminator, session, dis_data_loader, steps=20, name="seqgan"):
        super().__init__()
        self.name = 'DualityGap'

        self.batch_size = 64
        self.generate_num = 128
        # NOTE(review): oracle_file is only read in this class (via
        # load_train_data); presumably the caller writes it beforehand — confirm.
        self.oracle_file = 'save/oracle_dg_' + name + '.txt'
        # Overwritten with fresh generator samples on every evaluation.
        self.generator_file = 'save/generator_dg_' + name + '.txt'
        # NOTE(review): test_file is not used anywhere in this class.
        self.test_file = 'save/test_file_dg_' + name + '.txt'
        self.session = session
        self.generator = generator
        self.discriminator = discriminator
        self.dis_data_loader = dis_data_loader
        # Number of "worst-case" update steps for each of G and D.
        self.steps = steps
        # Snapshot/restore ops, keyed by (reverse, variable pairing). Repeated
        # get_score calls reuse these instead of adding new assign ops to the graph.
        self._swap_op_cache = {}

    def get_name(self):
        return self.name

    def get_score(self, ignore=False):
        """Compute the duality-gap estimate.

        Args:
            ignore: if true, skip the computation entirely and return 0.

        Returns:
            ``(-worst_minmax) - (-worst_maxmin)``. Both terms are values of
            ``discriminator.d_loss_no_reg`` on a single batch; see the class
            docstring for how each is obtained.

        Side effects:
            Temporarily trains the generator and the discriminator, then
            restores their snapshotted weights. The restore also runs if an
            exception interrupts the computation. Overwrites
            ``self.generator_file``, reloads ``self.dis_data_loader``, sets
            ``self.reward`` and prints the min-max value.
        """
        if ignore:
            return 0

        # Snapshot the live weights into the tmp copies before touching anything.
        self.session.run(self.set_current_to_tmp())
        try:
            # Phase 1: perturb the generator against the fixed discriminator.
            self._train_worst_generator()
            worst_maxmin = self._eval_dis_loss()

            # Undo the generator perturbation before perturbing the discriminator.
            self.session.run(self.set_current_to_tmp(reverse=True))

            # Phase 2: perturb the discriminator against the fixed generator.
            self._train_worst_discriminator()
            worst_minmax = self._eval_dis_loss()
        finally:
            # Always restore the snapshot. Evaluating the metric must not leave
            # the model trained, even when a step above fails.
            self.session.run(self.set_current_to_tmp(reverse=True))

        print("the minmax is:")
        print(-worst_minmax)

        return (-worst_minmax) - (-worst_maxmin)

    def _train_worst_generator(self):
        """Run ``self.steps`` generator updates using rewards from the discriminator."""
        # The constants .8 and 16 are forwarded unchanged to Reward/get_reward;
        # their meaning is defined by Reward.
        self.reward = Reward(self.generator, .8)
        for _ in range(self.steps):
            samples = self.generator.generate(self.session)
            rewards = self.reward.get_reward(self.session, samples, 16, self.discriminator)
            feed = {
                self.generator.x: samples,
                self.generator.rewards: rewards,
            }
            self.session.run(self.generator.g_updates, feed_dict=feed)

    def _train_worst_discriminator(self):
        """Refresh the discriminator data, then run ``self.steps`` discriminator training steps."""
        self._refresh_dis_data()
        for _ in range(self.steps):
            self.session.run(self.discriminator.train_op, self._next_dis_feed())

    def _eval_dis_loss(self):
        """Refresh the discriminator data and return ``d_loss_no_reg`` on one batch."""
        self._refresh_dis_data()
        return self.session.run(self.discriminator.d_loss_no_reg, self._next_dis_feed())

    def _refresh_dis_data(self):
        """Write fresh generator samples and reload the discriminator data loader."""
        generate_samples(self.session, self.generator, self.batch_size, self.generate_num, self.generator_file)
        self.dis_data_loader.load_train_data(self.oracle_file, self.generator_file)

    def _next_dis_feed(self):
        """Build a discriminator feed dict from the data loader's next batch."""
        # NOTE(review): one batch is fetched and discarded before the batch
        # that is actually used — confirm this is intended.
        self.dis_data_loader.next_batch()
        x_batch, y_batch = self.dis_data_loader.next_batch()
        return {
            self.discriminator.input_x: x_batch,
            self.discriminator.input_y: y_batch,
        }

    @staticmethod
    def _filter_vars(t_vars, name_substring):
        """Return variables whose name contains ``name_substring`` and not ``'Adam'``."""
        return [var for var in t_vars
                if name_substring in var.name and "Adam" not in var.name]

    def set_current_to_tmp(self, reverse=False):
        """Return a grouped op that copies weights between live and tmp variables.

        Args:
            reverse: if false, copy live -> tmp (snapshot). If true, copy
                tmp -> live (restore).

        Returns:
            A ``tf.group`` of assign ops, discriminator variables first; it is
            also stored in ``self.current_to_tmp``. The op is built once per
            (direction, variable pairing, graph) and reused on later calls, so
            repeated evaluations do not keep growing the graph.

        Raises:
            ValueError: if the live and tmp variable lists differ in length,
                so they cannot be paired by position.
        """
        t_vars = tf.global_variables()
        # NOTE(review): matching is by substring of the variable name, and
        # pairing is by position; both lists must be created in the same order.
        d_vars_tmp = self._filter_vars(t_vars, 'disc_tmp')
        d_vars_0 = self._filter_vars(t_vars, 'discriminator')
        g_vars_tmp = self._filter_vars(t_vars, 'gen_tmp')
        g_vars_0 = self._filter_vars(t_vars, 'generator_1')

        # Fail before any weights are modified rather than partway through a swap.
        if len(d_vars_0) != len(d_vars_tmp) or len(g_vars_0) != len(g_vars_tmp):
            raise ValueError(
                "Cannot pair live and tmp variables: discriminator %d vs %d, "
                "generator %d vs %d" % (len(d_vars_0), len(d_vars_tmp),
                                        len(g_vars_0), len(g_vars_tmp)))

        # (destination, source) pairs, discriminator first, then generator.
        if reverse:
            pairs = list(zip(d_vars_0, d_vars_tmp)) + list(zip(g_vars_0, g_vars_tmp))
        else:
            pairs = list(zip(d_vars_tmp, d_vars_0)) + list(zip(g_vars_tmp, g_vars_0))

        key = (reverse, tuple((dst.name, src.name) for dst, src in pairs))
        op = self._swap_op_cache.get(key)
        # Rebuild if never built, or if the cached op belongs to a different graph.
        if op is None or op.graph is not tf.get_default_graph():
            op = tf.group(*[dst.assign(src) for dst, src in pairs])
            self._swap_op_cache[key] = op

        self.current_to_tmp = op
        return self.current_to_tmp
